package rpc

import (
	"context"
	"errors"
	"gitee.com/zackeus/go-boot/tools/errorx"
	"gitee.com/zackeus/go-zero/core/logx"
	"gitee.com/zackeus/go-zero/core/stores/sqlx"
	"google.golang.org/grpc"
	codes2 "google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

func errHandle(ctx context.Context, err error) error {
	if err == nil {
		return err
	}

	isLog := true
	/* err类型 */
	causeErr := errorx.Cause(err)

	if e, ok := causeErr.(errorx.ErrorR); ok {
		/* 自定义错误类型 转成grpc err */
		err = status.Error(codes2.Code(e.Code()), e.Error())
	} else if errors.Is(causeErr, sqlx.ErrNotFound) {
		/* model not found */
		isLog = false
		err = status.Error(codes2.NotFound, "数据不存在")
	} else if e, ok := status.FromError(causeErr); ok {
		/* grpc error */
		switch e.Code() {
		case codes2.NotFound, codes2.DataLoss:
			/* 数据不存在 */
			isLog = false
		default:
			break
		}
	}
	if isLog {
		logx.WithContext(ctx).Errorf("【RPC-SRV-ERR】 %+v", err)
	}
	return err
}

// ErrUnaryInterceptor rpc 一元异常拦截器
func ErrUnaryInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
	resp, err = handler(ctx, req)
	return resp, errHandle(ctx, err)
}

// ErrStreamInterceptor rpc stream 异常拦截器
func ErrStreamInterceptor(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
	err := handler(srv, ss)
	return errHandle(ss.Context(), err)
}
